#include "force_field.h"
#include "comm.h"
#include "memory_cpp.hpp"
#include "timer.h"
#include "cell.h"
#include "listed.hpp"
#include <algorithm>
#ifdef __sw__
// Alternative implementations that are defined in another translation unit and
// only declared when __sw__ is defined. In force() below, force_reset_sw()
// takes the place of the per-atom force-zeroing loop and
// force_fdotr_compute_sw() takes the place of force_fdotr_compute().
extern void force_reset_sw(cellgrid_t *grid);
extern void force_fdotr_compute_sw(cellgrid_t *grid, mdstat_t *stat);
#endif
// Add the position-times-force products of every atom to stat->virial.
//
// For each atom of each cell visited by FOREACH_CELL, a position vector is
// built with vecaddv from cell->basis and cell->x[atom] (presumably the cell
// origin plus the atom's in-cell offset -- confirm against vecaddv), and six
// products with the atom's force are accumulated:
//   virial[0] += x*fx   virial[1] += y*fy   virial[2] += z*fz
//   virial[3] += x*fy   virial[4] += x*fz   virial[5] += y*fz
//
// grid: cell grid whose atoms are traversed.
// stat: statistics record; only stat->virial[0..5] is touched, and it is
//       added to, never reset, by this function.
void force_fdotr_compute(cellgrid_t *grid, mdstat_t *stat){
  FOREACH_CELL(grid, cx, cy, cz, cell){
    for (int atom = 0; atom < cell->natom; ++atom){
      vec<real> pos;
      vecaddv(pos, cell->basis, cell->x[atom]);
      const auto &frc = cell->f[atom];
      // Diagonal components.
      stat->virial[0] += pos.x * frc.x;
      stat->virial[1] += pos.y * frc.y;
      stat->virial[2] += pos.z * frc.z;
      // Off-diagonal components (xy, xz, yz).
      stat->virial[3] += pos.x * frc.y;
      stat->virial[4] += pos.x * frc.z;
      stat->virial[5] += pos.y * frc.z;
    }
  }
}
// Timers passed to timer_start()/timer_stop() in force(), one per phase of the
// force evaluation. The string is presumably the label under which the timer
// is reported -- confirm against DEF_TIMER in timer.h.
DEF_TIMER(BONDED, "force/bonded")
DEF_TIMER(NONBONDED, "force/nonbonded")
DEF_TIMER(FLONG, "force/long")
DEF_TIMER(FDOTR, "force/fdotr")
void force(ff_t *ff){
  #ifdef __sw__
  force_reset_sw(ff->grid);
  #else
  FOREACH_CELL(ff->grid, ii, jj, kk, cell){
    for (int i = 0; i < cell->natom; i++) {
      cell->f[i].x = cell->f[i].y = cell->f[i].z = 0;
    }
  }
  #endif
  memset(&ff->stat, 0, sizeof(ff->stat));
  for (enhance_t *enh : ff->enhs) {
    enh->pre_force(ff->grid, ff->mpp);
  }

  if (ff->use_fdotr & M_BOND){
    timer_start(BONDED);
    // static vec<real> (*fbak)[CELL_CAP] = (vec<real>(*)[CELL_CAP])esmd::allocate(ff->grid->dim.vol() * CELL_CAP * sizeof(vec<real>), "force backup");
    // compute_listed_forces(ff->grid, NULL, &ff->stat);
    // FOREACH_CELL(ff->grid, cx, cy, cz, cell) {
    //   int off = get_offset_xyz<true>(ff->grid, cx, cy, cz);
    //   for (int i = 0; i < cell->natom; i ++){
    //     fbak[off][i] = cell->f[i];
    //     cell->f[i] = 0;
    //   }
    // }

    ff->bonded(ff->grid, ff->param, &ff->stat);
    //ff->grid->topo->compute(&ff->stat, ff->grid);
    // real maxdiff = 0;
    // FOREACH_LOCAL_CELL(ff->grid, cx, cy, cz, cell) {
    //   int off = get_offset_xyz<true>(ff->grid, cx, cy, cz);
    //   for (int i = 0; i < cell->natom; i ++){
    //     vec<real> fdiff = cell->f[i] - fbak[off][i];
    //     if (fdiff.norm() > maxdiff) {
    //       printf("%d %d %d %d %ld\n", cx, cy, cz, i, cell->tag[i]);
    //       printf("%f %f %f\n", cell->f[i].x, cell->f[i].y, cell->f[i].z);
    //       printf("%f %f %f\n", fbak[off][i].x, fbak[off][i].y, fbak[off][i].z);
    //     }
    //     maxdiff = std::max(fdiff.norm(), maxdiff);
    //   }
    // }
    
    // printf("%f\n", maxdiff);
    //exit(0);
    timer_stop(BONDED);
  }
  if (ff->use_fdotr & M_NONB){
    timer_start(NONBONDED);
    ff->nonbonded(ff->grid, ff->param, &ff->stat);
    timer_stop(NONBONDED);
  }
  if (ff->use_fdotr & M_LONG){
    timer_start(FLONG);
    ff->longrange(ff->grid, ff->longgrid, ff->param, &ff->stat);
    timer_stop(FLONG);
  }
  timer_start(FDOTR);
  #ifdef __sw__
  force_fdotr_compute_sw(ff->grid, &ff->stat);
  #else
  force_fdotr_compute(ff->grid, &ff->stat);
  #endif
  timer_stop(FDOTR);
  if (ff->present & ~ff->use_fdotr & M_BOND){
    timer_start(BONDED);
    ff->bonded(ff->grid, ff->param, &ff->stat);
    timer_stop(BONDED);
  }
  if (ff->present & ~ff->use_fdotr & M_NONB){
    timer_start(NONBONDED);
    ff->nonbonded(ff->grid, ff->param, &ff->stat);
    timer_stop(NONBONDED);
  }
  if (ff->present & ~ff->use_fdotr & M_LONG){
    timer_start(FLONG);
    ff->longrange(ff->grid, ff->longgrid, ff->param, &ff->stat);
    timer_stop(FLONG);
  }
  for (enhance_t *enh : ff->enhs) {
    enh->post_force(ff->grid, ff->mpp);
  }

}
